#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 28 16:56:54 2021
Comparing GMOS and Chacaltaya Hg0 measurements with model run
Sites are Manaus, Nieuw Nickerie, Chacaltaya, Bariloche
@author: arifeinberg
"""

#%%
#import os
#os.chdir('/Users/arifeinberg/target2/fs03/d0/arifein/python/')

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats
from functions_Hg0 import get_mod_data, ds_sel_yr, annual_avg
import cartopy.crs as ccrs
import cartopy.feature as cf
from matplotlib import colors
import sys
#%% Important functions
def SA_Passive_sampler(j_GEM_surf):
    """Get observed and modelled mean Hg0 at the passive sampler sites

    Reads the passive sampler dataset, drops every site whose 'Type' is
    'Urban', 'Mining impacted' or 'Contaminated', and compares the remaining
    observations with the model sampled at the same coordinates.

    Parameters
    ----------
    j_GEM_surf : xarray dataarray
         Model Hg0 at the surface. It is handed unchanged to get_mod_data,
         so it has to be in the units of the observations (ng/m3) already
         for the returned means to be comparable.

    Returns
    -------
    obs_mean : float
        Mean of the retained observed concentrations (ng/m3)
    obs_std : float
        Standard deviation of the retained observed concentrations, using
        the np.std default (ddof=0)
    model_avg : scalar
        np.mean of whatever get_mod_data returns for the retained sites
    """

    source = '../obs_datasets/GEM/Passive_sampler_toronto.csv'
    data_SA = pd.read_csv(source)

    # Site categories left out of the comparison
    excluded_types = ['Urban', 'Mining impacted', 'Contaminated']
    data_SA_filt = data_SA.loc[~data_SA['Type'].isin(excluded_types)]

    obs_Hg0_fil = data_SA_filt['Measured Conc (ng/m3)'].values
    obs_lat_fil = data_SA_filt['Latitude'].values
    obs_lon_fil = data_SA_filt['Longitude'].values

    obs_mean = np.mean(obs_Hg0_fil)
    obs_std = np.std(obs_Hg0_fil)

    # Model values at the coordinates of the retained sites
    # (presumably one value per site; verify against get_mod_data)
    model_sites = get_mod_data(j_GEM_surf, obs_lat_fil, obs_lon_fil)
    model_avg = np.mean(model_sites)

    return obs_mean, obs_std, model_avg

def SA_GMOS(ds_list, ds_list_CHC, ds_names, ds_num, Year = None):
    """Plot observations from GMOS against model simulations

    For every simulation the surface model Hg0 is sampled at the observation
    sites and at the passive sampler sites; the values are drawn as one group
    of diamonds per site, next to the observed mean +/- standard deviation.

    Parameters
    ----------
    ds_list : list of xarray datasets
         Model datasets for Hg0 on model levels
    ds_list_CHC : list of xarray datasets
         Model datasets for Chacaltaya, one per entry of ds_list
    ds_names : list of strings
        Names of model simulations, for plotting
    ds_num : list of strings
        Numbers of model simulations; only its length is used here, as the
        number of simulations
    Year : list of int, optional
        Optional parameter to only select one year from each simulation
        (one entry per simulation). If omitted, None is passed to ds_sel_yr
        for every simulation (assumes ds_sel_yr treats None as "no year
        subsetting" - TODO confirm).

    Returns
    -------
    f1 : matplotlib figure
        Figure containing the comparison plot

    """
    # Observed values. Order of Obs_values / Obs_std:
    # Manaus, Nieuw Nickerie, Bariloche, Chacaltaya, Chacaltaya-ENSO;
    # the passive sampler statistics are appended further down (index 5).
    Obs_names = ['Manaus', 'Nieuw Nickerie','Bariloche', 'Chacaltaya', 'Chacaltaya-ENSO', 'Passive Samplers']
    Obs_names_plot = ['Manaus', 'Suriname','Bariloche', 'Chacaltaya', 'LAPAN']
#    Obs_values = [1.04, 1.21, 0.88, 0.89, 1.34] # Helene
    Obs_values = [1.03, 1.17, 0.86, 0.89, 1.34] # calculated by me
    Obs_std = [0.16,0.21,0.13,0.14,0.24] # calculated by me
    Obs_lat = [-2.89,5.96,-41.13,-16.35]
    Obs_lon = [-59.97,-57.04,-71.42,-68.13]

    # One marker colour per simulation
    list_colours = ['#e41a1c', '#377eb8', '#b2df8a', '#33a02c', '#984ea3',
                    '#beaed4', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
                    '#a6cee3', '#1f78b4', '#b2df8a', '#33a02c', '#fb9a99',
                    '#e31a1c', '#fdbf6f', '#ff7f00', '#cab2d6', '#e41a1c',
                    '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33',
                    '#a65628']

    # Unit conversion factor from vmr to ng/m^3 at standard temperature and
    # pressure: P / (R T) gives mol m^-3, times molar mass and ng per g.
    # Identical for every simulation, so it is computed once.
    R = 8.314462 # m^3 Pa K^-1 mol ^-1
    MW_Hg = 200.59 # g mol^-1
    ng_g = 1e9 # ng/g
    stdpressure = 101325 # Pascals
    stdtemp = 273.15 # Kelvins
    unit_conv = stdpressure / R / stdtemp * MW_Hg * ng_g # converter from vmr to ng m^-3

    # Year is optional: without it, hand None to ds_sel_yr for each simulation
    # instead of failing on Year[j]
    years = Year if Year is not None else [None] * len(ds_list)

    # Figure axes for the plot
    f1,  axes1 = plt.subplots(1, 1, figsize=[11,5],
                              gridspec_kw=dict(hspace=0.2, wspace=0.0))
    plt.subplots_adjust(top=0.8)
    n_obs = len(Obs_names)-1 # plotted groups; the two Chacaltaya obs share one
    n_sim = len(ds_num)
    # Columns: Manaus, Nieuw Nickerie, Bariloche, Chacaltaya, passive samplers
    GEM_vals_sim = np.zeros((n_sim, n_obs))

    # Loop over the model simulation datasets
    for j, jds in enumerate(ds_list):
        j_GEM_surf = ds_sel_yr(jds, 'SpeciesConc_Hg0', years[j]).isel(lev=0)
        j_RGM_surf = ds_sel_yr(jds, 'SpeciesConc_Hg2', years[j]).isel(lev=0)

        j_TGM_surf = (j_GEM_surf + j_RGM_surf) * unit_conv # now in ng m^-3
        j_GEM_surf = j_GEM_surf * unit_conv # now in ng m^-3

        # find model values at coordinates of observations
        j_GEM_site = get_mod_data(j_GEM_surf, Obs_lat, Obs_lon)
        # for NIK use TGM
        j_GEM_site[1] = get_mod_data(j_TGM_surf, Obs_lat[1], Obs_lon[1])

        # for Chacaltaya, use the interpolated file
        j_GEM_CHC = ds_sel_yr(ds_list_CHC[j], 'SpeciesConc_Hg0', years[j]).mean() * unit_conv
        j_GEM_site[3] = j_GEM_CHC  #replace surface value with value at 5240 m

        # Get values from passive samplers; the observed mean and std come
        # from the same file every time, only PS_model differs per simulation
        PS_mean, PS_std, PS_model = SA_Passive_sampler(j_GEM_surf)

        # Save simulation values
        GEM_vals_sim[j,:-1] = j_GEM_site.values
        GEM_vals_sim[j,-1] = PS_model # add Passive sampler value

    # append Passive Sampler observations to the observation lists
    Obs_values.append(PS_mean)
    Obs_std.append(PS_std)

    # Print the NEWCHEM values, located by name rather than by a fixed row
    if 'NEWCHEM' in ds_names:
        print("NEWCHEM")
        print(GEM_vals_sim[list(ds_names).index('NEWCHEM'),:])

    # Plot values at different observation sites
    # NOTE(review): the spacing formula and the tick position x_values[3]
    # look tuned for the six-simulation figure; fewer than 4 simulations
    # raises IndexError here.
    midpoints = np.zeros(n_obs) # for xticks
    for i in range(n_obs):
        x_values = np.linspace(2*i + 1.5/n_sim, 2*i + 1 - 2/n_sim, n_sim) # for spacing plot
        diff_x = (x_values[1] - x_values[0])*2. # for spacing plot
        x_obs = x_values[-1] + diff_x # observations sit right of the simulations

        midpoints[i] = x_values[3]

        for j in range(n_sim):
            # label the simulations only once (first group) for the legend
            legend_kw = {'label': ds_names[j]} if i == 0 else {}
            axes1.scatter(x_values[j],GEM_vals_sim[j,i], s=100,
                 c=list_colours[j], marker='d',  edgecolor = 'k', **legend_kw)

        if i < 3:
            axes1.errorbar(x_obs,Obs_values[i], yerr=Obs_std[i],
                           fmt='o',color='k',markerfacecolor='k', lw=2, ms=9)
        elif i == 3:
            # Chacaltaya: regular observations plus the ENSO value
            axes1.errorbar(x_obs,Obs_values[i], yerr=Obs_std[i],
                           fmt='o',color='k',markerfacecolor='k', lw=2, ms=9, label="Obs")
            axes1.errorbar(x_obs,Obs_values[i+1], yerr=Obs_std[i+1],
                           fmt='^',color='silver',markerfacecolor='silver', lw=2, ms=9, label="Obs (ENSO)")
        elif i == 4:
            # Passive samplers: observation index is shifted by the ENSO entry
            axes1.errorbar(x_obs,Obs_values[i+1], yerr=Obs_std[i+1],
                           fmt='o',color='k',markerfacecolor='k', lw=2, ms=9)

    # Axis formatting, applied once when all tick positions are known
    axes1.set_ylabel('Atmospheric Hg (ng m$^{-3}$)', fontsize=18)
    axes1.set_xticks(midpoints)
    axes1.set_xticklabels(Obs_names_plot, fontsize=18)
    axes1.tick_params(axis='y', labelsize=16 )

    handles, labels = axes1.get_legend_handles_labels()
    f1.legend(handles, labels, loc='upper center', ncol=4, columnspacing = 0., fontsize=17)

    return f1


#%% load model data and run plotting function

# Simulation IDs and the matching names shown in the legend
runs = ['0005','0017','0018','0019','0102','0105']
ds_names = ['BASE', 'OBRIST_R','AMAZON_L', 'AMAZON_U', 'NEWCHEM','NEWCHEM_D']

year_to_analyze = [2015] * len(runs) # year to analyze from simulations

# File names and opened datasets, one entry per simulation
fns = []
ds_list = []
fns_CHC = []
ds_list_CHC = []

for run_id in runs:
    # full model output of the run
    fn_run = f'../GEOS-Chem_runs/run{run_id}/OutputDir/GEOSChem.SpeciesConc.alltime_m.nc4'
    fns.append(fn_run)
    ds_list.append(xr.open_dataset(fn_run))

    # Chacaltaya (CHC) file of the same run
    fn_run_CHC = f'../GEOS-Chem_runs/run{run_id}/OutputDir/CHC_5240_Hg.nc4'
    fns_CHC.append(fn_run_CHC)
    ds_list_CHC.append(xr.open_dataset(fn_run_CHC))

# Make the comparison figure and write it to disk
figure_sites = SA_GMOS(ds_list, ds_list_CHC, ds_names, runs, year_to_analyze)

figure_sites.savefig('Figures/Fig4_Hg0_SA_GMOS_v2.pdf',bbox_inches = 'tight')
